import torch

from model.mobilenetv3 import MobileNetV3_Small_MultiOut, MobileNetV3_Large_MultiOut

# Export a trained MobileNetV3_Large_MultiOut checkpoint to TorchScript:
# build the model, load the saved weights onto the CPU, switch to eval mode,
# trace it with a random dummy input and write the traced module to disk.

CHECKPOINT_PATH = 'sn/epoch_300.pth.tar'
OUTPUT_PATH = "./model_v3.pt"
# Shape of the dummy tracing input, presumably NCHW
# (batch, channels, height, width) — TODO confirm against the model.
# NOTE(review): torch.jit.trace records only the operations executed for this
# example; confirm the traced module is valid for the input sizes it will be
# served with.
EXAMPLE_INPUT_SHAPE = (1, 3, 120, 120)

model = MobileNetV3_Large_MultiOut()
# map_location='cpu' lets a checkpoint saved from GPU tensors load on a
# CPU-only machine.
# SECURITY NOTE: torch.load unpickles the file, so only load trusted checkpoints.
# The loaded object is passed straight to load_state_dict, so the file is
# expected to contain a bare state_dict (not a wrapper dict).
checkpoint = torch.load(CHECKPOINT_PATH, map_location='cpu')
model.load_state_dict(checkpoint)
# Switch modules such as Dropout/BatchNorm (if present) to inference behaviour
# before tracing; otherwise training-mode behaviour would be baked into the trace.
model.eval()
example = torch.rand(*EXAMPLE_INPUT_SHAPE)
traced_script_module = torch.jit.trace(model, example)
traced_script_module.save(OUTPUT_PATH)

